{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# RCNN Architecture for Decoding EEG Motor Imagery (MI) Data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Import modules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import print_function\n",
    "# Theano reads THEANO_FLAGS when it is first imported, so this must be set before\n",
    "# anything that may import it (presumably gumpy and/or keras further down).\n",
    "import os; os.environ[\"THEANO_FLAGS\"] = \"device=gpu0\"\n",
    "import os.path\n",
    "from datetime import datetime\n",
    "import sys\n",
    "# make the local gumpy checkout importable (path is relative to the notebook's working directory)\n",
    "sys.path.append('../../gumpy')\n",
    "\n",
    "import gumpy\n",
    "import numpy as np\n",
    "# NOTE(review): datetime, scipy.io and plt are not used anywhere in this notebook.\n",
    "import scipy.io\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "\n",
    "To use the models provided by gumpy-deeplearning, we have to set the path to the models directory and import it. If you installed gumpy-deeplearning as a module, this step may not be required.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make the parent directory importable so the local `models` package can be found;\n",
    "# the append must run before the import.\n",
    "sys.path.append('..')\n",
    "import models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "The examples for gumpy-deeplearning ship with a few small helper functions. For instance, one of them prints the versions of the currently installed Keras and kapre. Keras is required by gumpy-deeplearning.\n",
    "In addition, the utilities provide a function `load_preprocess_data` that loads and preprocesses the data. Its usage is shown further below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `utils` is the examples' helper module; print the installed keras / kapre versions.\n",
    "import utils\n",
    "utils.print_version_info()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup parameters for the model and data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DEBUG = True\n",
    "\n",
    "# The RCNN does not use spectrograms: it consumes the raw (filtered) signals directly,\n",
    "# so the data has to be shaped differently. Set RCNN_FLAG to enable that reshaping.\n",
    "RCNN_FLAG = True\n",
    "\n",
    "CLASS_COUNT = 2\n",
    "# dropout rate in float\n",
    "# NOTE(review): DROPOUT is not passed to the model; it only appears in the CV output file name.\n",
    "DROPOUT = 0.2\n",
    "\n",
    "# parameters for filtering data\n",
    "FS = 250          # sampling frequency (Hz)\n",
    "LOWCUT = 2        # band-pass lower edge (Hz)\n",
    "HIGHCUT = 60      # band-pass upper edge (Hz)\n",
    "ANTI_DRIFT = 0.5  # presumably a high-pass cutoff (Hz) against baseline drift -- TODO confirm\n",
    "CUTOFF = 50.0     # freq to be removed from signal (Hz) for notch filter\n",
    "Q = 30.0          # quality factor for notch filter\n",
    "W0 = CUTOFF/(FS/2)  # notch frequency normalised to the Nyquist frequency\n",
    "AXIS = 0\n",
    "\n",
    "# set random seeds: SEED drives the cross-validation shuffling, and seeding NumPy's\n",
    "# global RNG makes NumPy-based randomness (presumably including Keras' batch\n",
    "# shuffling / weight initialisation) repeatable across Restart & Run All\n",
    "SEED = 42\n",
    "np.random.seed(SEED)\n",
    "KFOLD = 5"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load raw data "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_dir = '../../grazdata'\n",
    "subject = 'B01'\n",
    "\n",
    "# initialize the data-structure, but do _not_ load the data yet\n",
    "grazb_data = gumpy.data.GrazB(data_dir, subject)\n",
    "\n",
    "# now that the dataset is set up, we can load the data. This is handled by the utils function,\n",
    "# which first loads the data and subsequently filters it using a notch and a band-pass filter.\n",
    "# The utility function then returns the training data.\n",
    "# The second positional argument is presumably load_preprocess_data's debug/verbosity\n",
    "# switch, so it is driven by the DEBUG config constant instead of a bare literal\n",
    "# (NOTE(review): confirm against the signature in utils).\n",
    "x_train, y_train = utils.load_preprocess_data(grazb_data, DEBUG, LOWCUT, HIGHCUT, W0, Q, ANTI_DRIFT, CLASS_COUNT, CUTOFF, AXIS, FS)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Augment data "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sliding-window augmentation of the filtered training data.\n",
    "# The window parameters are expressed via FS (presumably in samples, so FS samples = 1 s):\n",
    "# 4 * FS window length, FS // 10 hop, and a start offset of FS.\n",
    "window_len = 4 * FS\n",
    "hop_len = FS // 10\n",
    "start_offset = FS * 1\n",
    "\n",
    "x_augmented, y_augmented = gumpy.signal.sliding_window(\n",
    "    data=x_train[:, :, :],\n",
    "    labels=y_train[:, :],\n",
    "    window_sz=window_len,\n",
    "    n_hop=hop_len,\n",
    "    n_start=start_offset)\n",
    "\n",
    "# move axis 2 to position 1 (i.e. swap axes 1 and 2 for 3-D data)\n",
    "x_subject = np.rollaxis(x_augmented, 2, 1)\n",
    "y_subject = y_augmented"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Run the model "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import StratifiedKFold\n",
    "from models import RCNN\n",
    "\n",
    "# Build the model input exactly once, outside the fold loop. Doing it inside the loop\n",
    "# would re-apply the transform on every fold (adding another axis each time) and make\n",
    "# this cell non-idempotent; x_subject itself is left untouched.\n",
    "# With RCNN_FLAG the raw signals get their axes swapped back plus a singleton axis 1.\n",
    "# NOTE(review): create_model() receives the per-sample shape *before* that reshaping;\n",
    "# confirm RCNN.create_model expects this layout rather than x_model.shape[1:].\n",
    "input_shape = x_subject.shape[1:]\n",
    "if RCNN_FLAG:\n",
    "    x_model = np.rollaxis(x_subject, 2, 1)[:, np.newaxis, :, :]\n",
    "else:\n",
    "    x_model = x_subject\n",
    "\n",
    "# define KFOLD-fold cross validation test harness (stratified on label column 0)\n",
    "kfold = StratifiedKFold(n_splits=KFOLD, shuffle=True, random_state=SEED)\n",
    "cvscores = []\n",
    "for run_idx, (train, test) in enumerate(kfold.split(x_model, y_subject[:, 0]), start=1):\n",
    "    print('Run ' + str(run_idx) + '...')\n",
    "    model_name_str = 'GRAZ_RCNN_run_' + str(run_idx)\n",
    "\n",
    "    # initialize and create a fresh model for this fold; it has to exist before\n",
    "    # its callbacks can be requested from it\n",
    "    model = RCNN(model_name_str)\n",
    "    model.create_model(input_shape, print_summary=False, class_count=CLASS_COUNT)\n",
    "    callbacks_list = model.get_callbacks(model_name_str)\n",
    "\n",
    "    # fit model. If you specify monitor=True, then the model will create callbacks\n",
    "    # and write its state to a HDF5 file\n",
    "    model.fit(x_model[train], y_subject[train], monitor=True,\n",
    "              epochs=100,\n",
    "              batch_size=256,\n",
    "              verbose=0,\n",
    "              validation_split=0.1, callbacks=callbacks_list)\n",
    "\n",
    "    # evaluate the model on the held-out fold\n",
    "    print('Evaluating model on test set...')\n",
    "    scores = model.evaluate(x_model[test], y_subject[test], verbose=0)\n",
    "    print(\"Result on test set: %s: %.2f%%\" % (model.metrics_names[1], scores[1] * 100))\n",
    "    cvscores.append(scores[1] * 100)\n",
    "\n",
    "# print some evaluation statistics and write results to file\n",
    "print(\"%.2f%% (+/- %.2f%%)\" % (np.mean(cvscores), np.std(cvscores)))\n",
    "cv_all_subjects = np.asarray(cvscores)\n",
    "print('Saving CV values to file....')\n",
    "# output name identifies the RCNN model trained in this notebook\n",
    "np.savetxt('GRAZ_CV_' + 'RCNN_' + str(DROPOUT) + 'do' + '.csv',\n",
    "           cv_all_subjects, delimiter=',', fmt='%2.4f')\n",
    "print('CV values successfully saved!\\n')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
